{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 12_Train_final_model_RotatE\n",
    "#\n",
    "# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on March 3, 2023\n",
    "# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on March 11, 2023\n",
    "#\n",
    "# 该脚本展示了如何在 DRKG 上训练最终模型 (RotatE).\n",
    "#\n",
    "# 需要的包:\n",
    "#          torch\n",
    "#          dgl, version: 0.4.3\n",
    "#          dglke\n",
    "#\n",
    "# 需要的文件:\n",
    "#          ../../data/drkg/drkg.tsv\n",
    "#\n",
    "# 源教程链接: https://github.com/gnn4dr/DRKG/blob/master/embedding_analysis/Train_embeddings.ipynb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training DRKG Using RotatE\n",
    "\n",
    "这个 notebook 展示了如何在 DRKG 上训练最终模型 (RotatE)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 导入需要的库"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们能使用 DGL-KE 命令训练 RotatE 模型, 关于如何使用 DGL-KE 的更多信息请参考 https://github.com/awslabs/dgl-ke.\n",
    "\n",
    "这里我们使用两个 GPU 训练模型."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- batch_size: **4096**\n",
    "\n",
    "- neg_sample_size: **256**\n",
    "\n",
    "- hidden_dim: **200**\n",
    "\n",
    "- gamma: 6, 12, **18**\n",
    "\n",
    "- lr: 0.01, **0.05**, 0.1\n",
    "\n",
    "注意: 下面的参数和网格搜索时有些许不同:\n",
    "\n",
    "- max_step: **60000** -> **100000**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reading train triples....\n",
      "Finished. Read 5874261 train triples.\n",
      "|Train|: 5874261\n",
      "random partition 5874261 edges into 2 parts\n",
      "part 0 has 2937131 edges\n",
      "part 1 has 2937130 edges\n",
      "/home/luyanfeng/miniconda3/envs/drkg/lib/python3.8/site-packages/dgl/base.py:25: UserWarning: multigraph will be deprecated.DGL will treat all graphs as multigraph in the future.\n",
      "  warnings.warn(msg, warn_type)\n",
      "Total initialize time 14.715 seconds\n",
      "[proc 1][Train](20000/100000) average pos_loss: 0.40681921094208956\n",
      "[proc 1][Train](20000/100000) average neg_loss: 0.48848652160540224\n",
      "[proc 1][Train](20000/100000) average loss: 0.44765286628901957\n",
      "[proc 0][Train](20000/100000) average pos_loss: 0.40666280265301463\n",
      "[proc 1][Train](20000/100000) average regularization: 0.0005929242649363004\n",
      "[proc 1][Train] 20000 steps take 3214.189 seconds\n",
      "[proc 1]sample: 59.439, forward: 1244.045, backward: 76.657, update: 1833.772\n",
      "[proc 0][Train](20000/100000) average neg_loss: 0.4899400228448212\n",
      "[proc 0][Train](20000/100000) average loss: 0.4483014127209783\n",
      "[proc 0][Train](20000/100000) average regularization: 0.0005927958798518376\n",
      "[proc 0][Train] 20000 steps take 3214.211 seconds\n",
      "[proc 0]sample: 59.942, forward: 1233.702, backward: 75.702, update: 1834.150\n",
      "[proc 1][Train](40000/100000) average pos_loss: 0.37042098518908023\n",
      "[proc 1][Train](40000/100000) average neg_loss: 0.46151476120352747\n",
      "[proc 0][Train](40000/100000) average pos_loss: 0.3702409432798624\n",
      "[proc 1][Train](40000/100000) average loss: 0.41596787329018114\n",
      "[proc 1][Train](40000/100000) average regularization: 0.0006071507920132717\n",
      "[proc 1][Train] 20000 steps take 3214.544 seconds\n",
      "[proc 1]sample: 58.714, forward: 1246.361, backward: 75.572, update: 1833.687\n",
      "[proc 0][Train](40000/100000) average neg_loss: 0.4620280939161778\n",
      "[proc 0][Train](40000/100000) average loss: 0.4161345186293125\n",
      "[proc 0][Train](40000/100000) average regularization: 0.0006070671200897777\n",
      "[proc 0][Train] 20000 steps take 3214.539 seconds\n",
      "[proc 0]sample: 59.690, forward: 1235.138, backward: 76.603, update: 1831.749\n",
      "[proc 1][Train](60000/100000) average pos_loss: 0.36835208870470526\n",
      "[proc 1][Train](60000/100000) average neg_loss: 0.4579752182945609\n",
      "[proc 1][Train](60000/100000) average loss: 0.41316365341991185\n",
      "[proc 1][Train](60000/100000) average regularization: 0.0006119719384383643\n",
      "[proc 1][Train] 20000 steps take 3215.447 seconds\n",
      "[proc 1]sample: 58.564, forward: 1246.583, backward: 76.388, update: 1833.116\n",
      "[proc 0][Train](60000/100000) average pos_loss: 0.3682194778814912\n",
      "[proc 0][Train](60000/100000) average neg_loss: 0.45825662498921155\n",
      "[proc 0][Train](60000/100000) average loss: 0.4132380513638258\n",
      "[proc 0][Train](60000/100000) average regularization: 0.0006118914026214042\n",
      "[proc 0][Train] 20000 steps take 3215.456 seconds\n",
      "[proc 0]sample: 59.858, forward: 1236.015, backward: 77.607, update: 1832.277\n",
      "[proc 1][Train](80000/100000) average pos_loss: 0.36727531443089245\n",
      "[proc 0][Train](80000/100000) average pos_loss: 0.3672783232614398\n",
      "[proc 1][Train](80000/100000) average neg_loss: 0.4562994497969747\n",
      "[proc 1][Train](80000/100000) average loss: 0.41178738210350274\n",
      "[proc 0][Train](80000/100000) average neg_loss: 0.4561406337186694\n",
      "[proc 1][Train](80000/100000) average regularization: 0.0006149285451072501\n",
      "[proc 0][Train](80000/100000) average loss: 0.41170947849452494\n",
      "[proc 1][Train] 20000 steps take 3219.608 seconds\n",
      "[proc 1]sample: 58.992, forward: 1251.124, backward: 75.310, update: 1833.960\n",
      "[proc 0][Train](80000/100000) average regularization: 0.0006148581685760291\n",
      "[proc 0][Train] 20000 steps take 3219.589 seconds\n",
      "[proc 0]sample: 60.080, forward: 1240.102, backward: 77.444, update: 1833.515\n",
      "[proc 1][Train](100000/100000) average pos_loss: 0.36664258914738895\n",
      "[proc 1][Train](100000/100000) average neg_loss: 0.45499095550477503\n",
      "[proc 0][Train](100000/100000) average pos_loss: 0.3666928806349635\n",
      "[proc 1][Train](100000/100000) average loss: 0.41081677231639624\n",
      "[proc 1][Train](100000/100000) average regularization: 0.0006169096214638558\n",
      "[proc 1][Train] 20000 steps take 3222.328 seconds\n",
      "[proc 1]sample: 58.682, forward: 1253.753, backward: 74.797, update: 1834.894\n",
      "proc 1 takes 16086.116 seconds\n",
      "[proc 0][Train](100000/100000) average neg_loss: 0.4550282849103212\n",
      "[proc 0][Train](100000/100000) average loss: 0.4108605828806758\n",
      "[proc 0][Train](100000/100000) average regularization: 0.000616843212855747\n",
      "[proc 0][Train] 20000 steps take 3222.328 seconds\n",
      "[proc 0]sample: 60.482, forward: 1239.720, backward: 77.813, update: 1832.804\n",
      "proc 0 takes 16086.124 seconds\n",
      "training takes 16086.556456804276 seconds\n",
      "Save model to ckpts/RotatE_All_DRKG_0\n"
     ]
    }
   ],
   "source": [
    "# Train the final RotatE model on DRKG with the DGL-KE command-line trainer.\n",
    "# The comments must stay above the command: the `!` shell escape is continued\n",
    "# with trailing backslashes, so nothing can be inserted between its lines.\n",
    "#\n",
    "# Flags as passed below:\n",
    "#   --data_path / --data_files / --format: read ../../data/drkg/drkg.tsv as a raw\n",
    "#       user-defined dataset (presumably head, relation, tail column order;\n",
    "#       TODO confirm against the DGL-KE docs).\n",
    "#   --model_name RotatE -de: RotatE model; -de presumably doubles the entity\n",
    "#       embedding size (TODO confirm against the DGL-KE docs).\n",
    "#   --batch_size / --neg_sample_size / --hidden_dim / --gamma / --lr: the values\n",
    "#       marked in bold in the hyperparameter list of this notebook.\n",
    "#   --gpu 0 1 --num_proc 2: two training processes (proc 0 and proc 1 in the log).\n",
    "#   --max_step 100000 --log_interval 20000: 100000 steps, loss printed every 20000.\n",
    "#\n",
    "# In the recorded output of this cell, training took about 16086 seconds and the\n",
    "# model was saved to ckpts/RotatE_All_DRKG_0.\n",
    "!DGLBACKEND=pytorch dglke_train --dataset All_DRKG --data_path ../../data/drkg \\\n",
    "--data_files drkg.tsv --format 'raw_udd_hrt' \\\n",
    "--model_name RotatE -de \\\n",
    "--batch_size 4096 --neg_sample_size 256 --hidden_dim 200 \\\n",
    "--gamma 18.0 --lr 0.05 --max_step 100000 -adv --regularization_coef 1.00E-07 \\\n",
    "--gpu 0 1 --num_proc 2 --mix_cpu_gpu --async_update --force_sync_interval 1000 \\\n",
    "--log_interval 20000 --num_thread 32"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
